{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "56cb76dc",
   "metadata": {},
   "source": [
    "# MultiBERTs vs. Original BERT\n",
    "\n",
    "Here, we'll compare the MultiBERTs models run for 2M steps with the single previously-released `bert-base-uncased` model, as described in **Appendix E.2** of [the paper](https://openreview.net/pdf?id=K0E_F0gFDgA). Our analysis will be unpaired with respect to seeds, but we'll still sample jointly over _examples_ in the evaluation set and report confidence intervals as described in Section 3 of the paper.\n",
    "\n",
    "We'll use SQuAD 2.0 here, but the code below can easily be modified to handle other tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "599dbb4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from tqdm.notebook import tqdm  # for progress indicator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "da6ae100",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 10547  100 10547    0     0  60614      0 --:--:-- --:--:-- --:--:-- 60614\n",
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 4268k  100 4268k    0     0  24.0M      0 --:--:-- --:--:-- --:--:-- 23.9M\n",
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 31869  100 31869    0     0   220k      0 --:--:-- --:--:-- --:--:--  220k\n"
     ]
    }
   ],
   "source": [
    "scratch_dir = \"/tmp/multiberts_squad\"\n",
    "if not os.path.isdir(scratch_dir): \n",
    "    os.mkdir(scratch_dir)\n",
    "    \n",
    "preds_root = \"https://storage.googleapis.com/multiberts/public/example-predictions/SQuAD\"\n",
    "# Fetch SQuAD eval script. Rename to allow module import, as this is invalid otherwise.\n",
    "!curl $preds_root/evaluate-v2.0.py -o $scratch_dir/evaluate_squad2.py\n",
    "# Fetch development set labels\n",
    "!curl -O $preds_root/dev-v2.0.json --output-dir $scratch_dir\n",
    "# Fetch predictions index file\n",
    "!curl -O $preds_root/index.tsv --output-dir $scratch_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c5497da9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dev-v2.0.json  evaluate_squad2.py  index.tsv  __pycache__  v2.0\r\n"
     ]
    }
   ],
   "source": [
    "!ls $scratch_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f051b7",
   "metadata": {},
   "source": [
    "Load the run metadata. You can also just look through the directory, but this index file is convenient if (as we do here) you only want to download some of the files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "10902bc1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>file</th>\n",
       "      <th>task</th>\n",
       "      <th>pretrain_id</th>\n",
       "      <th>n_steps</th>\n",
       "      <th>lr</th>\n",
       "      <th>ft_seed</th>\n",
       "      <th>release</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>55</th>\n",
       "      <td>v2.0/release=multiberts,pretrain_id=0,n_steps=...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2M</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0</td>\n",
       "      <td>multiberts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56</th>\n",
       "      <td>v2.0/release=multiberts,pretrain_id=0,n_steps=...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2M</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>1</td>\n",
       "      <td>multiberts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>57</th>\n",
       "      <td>v2.0/release=multiberts,pretrain_id=0,n_steps=...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2M</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>2</td>\n",
       "      <td>multiberts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>58</th>\n",
       "      <td>v2.0/release=multiberts,pretrain_id=0,n_steps=...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2M</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>3</td>\n",
       "      <td>multiberts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>59</th>\n",
       "      <td>v2.0/release=multiberts,pretrain_id=0,n_steps=...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2M</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>4</td>\n",
       "      <td>multiberts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>305</th>\n",
       "      <td>v2.0/release=public,pretrain_id=0,n_steps=0,lr...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>0</td>\n",
       "      <td>public</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>306</th>\n",
       "      <td>v2.0/release=public,pretrain_id=0,n_steps=0,lr...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>1</td>\n",
       "      <td>public</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>307</th>\n",
       "      <td>v2.0/release=public,pretrain_id=0,n_steps=0,lr...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>2</td>\n",
       "      <td>public</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>308</th>\n",
       "      <td>v2.0/release=public,pretrain_id=0,n_steps=0,lr...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>3</td>\n",
       "      <td>public</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>309</th>\n",
       "      <td>v2.0/release=public,pretrain_id=0,n_steps=0,lr...</td>\n",
       "      <td>v2.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.00005</td>\n",
       "      <td>4</td>\n",
       "      <td>public</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>130 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  file  task  pretrain_id  \\\n",
       "55   v2.0/release=multiberts,pretrain_id=0,n_steps=...  v2.0            0   \n",
       "56   v2.0/release=multiberts,pretrain_id=0,n_steps=...  v2.0            0   \n",
       "57   v2.0/release=multiberts,pretrain_id=0,n_steps=...  v2.0            0   \n",
       "58   v2.0/release=multiberts,pretrain_id=0,n_steps=...  v2.0            0   \n",
       "59   v2.0/release=multiberts,pretrain_id=0,n_steps=...  v2.0            0   \n",
       "..                                                 ...   ...          ...   \n",
       "305  v2.0/release=public,pretrain_id=0,n_steps=0,lr...  v2.0            0   \n",
       "306  v2.0/release=public,pretrain_id=0,n_steps=0,lr...  v2.0            0   \n",
       "307  v2.0/release=public,pretrain_id=0,n_steps=0,lr...  v2.0            0   \n",
       "308  v2.0/release=public,pretrain_id=0,n_steps=0,lr...  v2.0            0   \n",
       "309  v2.0/release=public,pretrain_id=0,n_steps=0,lr...  v2.0            0   \n",
       "\n",
       "    n_steps       lr  ft_seed     release  \n",
       "55       2M  0.00005        0  multiberts  \n",
       "56       2M  0.00005        1  multiberts  \n",
       "57       2M  0.00005        2  multiberts  \n",
       "58       2M  0.00005        3  multiberts  \n",
       "59       2M  0.00005        4  multiberts  \n",
       "..      ...      ...      ...         ...  \n",
       "305       0  0.00005        0      public  \n",
       "306       0  0.00005        1      public  \n",
       "307       0  0.00005        2      public  \n",
       "308       0  0.00005        3      public  \n",
       "309       0  0.00005        4      public  \n",
       "\n",
       "[130 rows x 7 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run_info = pd.read_csv(os.path.join(scratch_dir, 'index.tsv'), sep='\\t')\n",
    "# Filter to SQuAD 2.0 runs from either 2M MultiBERTs or the original BERT checkpoint (\"public\").\n",
    "mask = run_info.task == \"v2.0\"\n",
    "mask &= (run_info.n_steps == \"2M\") | (run_info.release == 'public')\n",
    "run_info = run_info[mask]\n",
    "run_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1b00369b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "379e318c7b3a423ea69f8f296d6c3989",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/130 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Download all prediction files\n",
    "for fname in tqdm(run_info.file):\n",
    "    !curl $preds_root/$fname -o $scratch_dir/$fname --create-dirs --silent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a29f52b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'release=multiberts,pretrain_id=0,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=0,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=0,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=0,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=0,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=10,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=10,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=10,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=10,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=10,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=11,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=11,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=11,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=11,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=11,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=12,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=12,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=12,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=12,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=12,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=13,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=13,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=13,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=13,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=13,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=14,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=14,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=14,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=14,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=14,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=15,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=15,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=15,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=15,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=15,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=16,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=16,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=16,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=16,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=16,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=17,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=17,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=17,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=17,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=17,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=18,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=18,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=18,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=18,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=18,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=19,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=19,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=19,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=19,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=19,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=1,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=1,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=1,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=1,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=1,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=20,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=20,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=20,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=20,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=20,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=21,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=21,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=21,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=21,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=21,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=22,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=22,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=22,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=22,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=22,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=23,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=23,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=23,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=23,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=23,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=24,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=24,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=24,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=24,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=24,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=2,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=2,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=2,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=2,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=2,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=3,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=3,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=3,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=3,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=3,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=4,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=4,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=4,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=4,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=4,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=5,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=5,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=5,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=5,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=5,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=6,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=6,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=6,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=6,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=6,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=7,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=7,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=7,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=7,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=7,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=8,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=8,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=8,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=8,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=8,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=multiberts,pretrain_id=9,n_steps=2M,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=multiberts,pretrain_id=9,n_steps=2M,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=multiberts,pretrain_id=9,n_steps=2M,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=multiberts,pretrain_id=9,n_steps=2M,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=multiberts,pretrain_id=9,n_steps=2M,lr=5e-05,ft_seed=4.json'\r\n",
      "'release=public,pretrain_id=0,n_steps=0,lr=5e-05,ft_seed=0.json'\r\n",
      "'release=public,pretrain_id=0,n_steps=0,lr=5e-05,ft_seed=1.json'\r\n",
      "'release=public,pretrain_id=0,n_steps=0,lr=5e-05,ft_seed=2.json'\r\n",
      "'release=public,pretrain_id=0,n_steps=0,lr=5e-05,ft_seed=3.json'\r\n",
      "'release=public,pretrain_id=0,n_steps=0,lr=5e-05,ft_seed=4.json'\r\n"
     ]
    }
   ],
   "source": [
    "!ls $scratch_dir/v2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86bc0010",
   "metadata": {},
   "source": [
    "Now we should have everything in our scratch directory, and can load individual predictions.\n",
    "\n",
    "SQuAD has a monolithic eval script that isn't easily compatible with a bootstrap procedure (among other things, it parses a lot of JSON, and you don't want to do that in the inner loop!). Ultimately, though, it relies on computing some point-wise scores (exact-match $\\in \\{0,1\\}$ and F1 $\\in [0,1]$) and averaging these across examples. For efficiency, we'll pre-compute these before running our bootstrap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "be695204",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the SQuAD 2.0 eval script; we'll use some functions from this below.\n",
    "import sys\n",
    "sys.path.append(scratch_dir)\n",
    "import evaluate_squad2 as squad_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "08970fbc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "with open(os.path.join(scratch_dir, 'dev-v2.0.json')) as fd:\n",
    "    dataset = json.load(fd)['data']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "901edb23",
   "metadata": {},
   "source": [
    "The official script supports thresholding for no-answer, but the default settings ignore this and treat only predictions of emptystring (`\"\"`) as no-answer. So, we can score on `exact_raw` and `f1_raw` directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "131aa7d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30c305d44458404bb03f0b61098eb402",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/130 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "exact_scores = {}  # filename -> qid -> score\n",
    "f1_scores = {}     # filename -> qid -> score\n",
    "for fname in tqdm(run_info.file):\n",
    "    with open(os.path.join(scratch_dir, fname)) as fd:\n",
    "        preds = json.load(fd)\n",
    "    \n",
    "    exact_raw, f1_raw = squad_eval.get_raw_scores(dataset, preds)\n",
    "    exact_scores[fname] = exact_raw\n",
    "    f1_scores[fname] = f1_raw\n",
    "    \n",
    "def dict_of_dicts_to_matrix(dd):\n",
    "    \"\"\"Convert a scores to a dense matrix.\n",
    "    \n",
    "    Outer keys assumed to be rows, inner keys are columns (e.g. example IDs).\n",
    "    Uses pandas to ensure that different rows are correctly aligned.\n",
    "    \n",
    "    Args:\n",
    "      dd: map of row -> column -> value\n",
    "      \n",
    "    Returns:\n",
    "      np.ndarray of shape [num_rows, num_columns]\n",
    "    \"\"\"\n",
    "    # Use pandas to ensure keys are correctly aligned.\n",
    "    df = pd.DataFrame(dd).transpose()\n",
    "    return df.values\n",
    "\n",
    "exact_scores = dict_of_dicts_to_matrix(exact_scores)\n",
    "f1_scores = dict_of_dicts_to_matrix(f1_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d567e034",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(130, 11873)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exact_scores.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9eb0e3c",
   "metadata": {},
   "source": [
    "## Run multibootstrap\n",
    "\n",
    "base (`L`) is the original BERT checkpoint, expt (`L'`) is MultiBERTs with 2M steps. Since we pre-computed the pointwise exact match and F1 scores for each run and each example, we can just pass dummy labels and use a simple average over predictions as our scoring function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8cb71332",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metric: exact\n",
      "Multibootstrap (unpaired) on 11873 examples\n",
      "  Base seeds (1): [0]\n",
      "  Expt seeds (25): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e3f5b387bdc34f838302c4224500c534",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrap statistics from 1000 samples:\n",
      "  E[L]  = 0.722 with 95% CI of (0.715 to 0.729)\n",
      "  E[L'] = 0.747 with 95% CI of (0.741 to 0.754)\n",
      "  E[L'-L] = 0.0254 with 95% CI of (0.0211 to 0.0296); p-value = 0\n",
      "\n",
      "Metric: f1\n",
      "Multibootstrap (unpaired) on 11873 examples\n",
      "  Base seeds (1): [0]\n",
      "  Expt seeds (25): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "703a47a7b7fb43bda2d73040efdc2646",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bootstrap statistics from 1000 samples:\n",
      "  E[L]  = 0.754 with 95% CI of (0.748 to 0.762)\n",
      "  E[L'] = 0.778 with 95% CI of (0.773 to 0.785)\n",
      "  E[L'-L] = 0.0239 with 95% CI of (0.0191 to 0.0286); p-value = 0\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"4\" halign=\"left\">exact</th>\n",
       "      <th colspan=\"4\" halign=\"left\">f1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>low</th>\n",
       "      <th>high</th>\n",
       "      <th>p</th>\n",
       "      <th>mean</th>\n",
       "      <th>low</th>\n",
       "      <th>high</th>\n",
       "      <th>p</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>base</th>\n",
       "      <td>0.721923</td>\n",
       "      <td>0.714658</td>\n",
       "      <td>0.728527</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.754427</td>\n",
       "      <td>0.748034</td>\n",
       "      <td>0.761508</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>expt</th>\n",
       "      <td>0.747291</td>\n",
       "      <td>0.740529</td>\n",
       "      <td>0.753553</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.778307</td>\n",
       "      <td>0.772585</td>\n",
       "      <td>0.784535</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>delta</th>\n",
       "      <td>0.025368</td>\n",
       "      <td>0.021090</td>\n",
       "      <td>0.029650</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.023880</td>\n",
       "      <td>0.019124</td>\n",
       "      <td>0.028629</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          exact                                 f1                         \n",
       "           mean       low      high    p      mean       low      high    p\n",
       "base   0.721923  0.714658  0.728527  NaN  0.754427  0.748034  0.761508  NaN\n",
       "expt   0.747291  0.740529  0.753553  NaN  0.778307  0.772585  0.784535  NaN\n",
       "delta  0.025368  0.021090  0.029650  0.0  0.023880  0.019124  0.028629  0.0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import multibootstrap\n",
    "\n",
    "num_bootstrap_samples = 1000\n",
    "\n",
    "selected_runs = run_info.copy()\n",
    "selected_runs['seed'] = selected_runs['pretrain_id']\n",
    "selected_runs['intervention'] = (selected_runs['release'] == 'multiberts')\n",
    "\n",
    "# Dummy labels\n",
    "dummy_labels = np.zeros_like(exact_scores[0])  # [num_examples]\n",
    "score_fn = lambda y_true, y_pred: np.mean(y_pred)\n",
    "\n",
    "# Targets; run once for each.\n",
    "targets = {'exact': exact_scores, 'f1': f1_scores}\n",
    "\n",
    "stats = {}\n",
    "for name, preds in targets.items():\n",
    "    print(f\"Metric: {name:s}\")\n",
    "    samples = multibootstrap.multibootstrap(selected_runs, preds, dummy_labels, score_fn,\n",
    "                                            nboot=num_bootstrap_samples,\n",
    "                                            paired_seeds=False,\n",
    "                                            progress_indicator=tqdm)\n",
    "    stats[name] = multibootstrap.report_ci(samples, c=0.95)\n",
    "    print(\"\")\n",
    "\n",
    "pd.concat({k: pd.DataFrame(v) for k,v in stats.items()}).transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7205b96b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
